use std::fs::File;
use std::io::{Write};
use ndarray::array;
use linfa::prelude::*;
use linfa_trees::{DecisionTree, SplitQuality};
use ndarray::prelude::*;

fn main() {
    let original_data: Array2<f32> = array!(
        [1.,1.,1000.,1.,10.],
        [1.,0.,0.,1.,6.],
        [1.,0.,0.,1.,6.],
        [1.,0.,0.,1.,6.],
        [1.,0.,0.,1.,6.],
        [1.,0.,800.,1.,8.],
        [1.,0.,0.,0.,0.],
        [1.,1.,0.,1.,9.],
        [1.,1.,0.,1.,8.],
        [1.,0.,800.,0.,8.],
        [1.,1.,0.,1.,8.],
        [1.,1.,500.,0.,8.],
        [1.,0.,50.,0.,3.],
        [1.,1.,50.,0.,4.],
        [1.,0.,50.,0.,3.],
    );

    // 是否看电视, 是否撸猫, 写了多少行rust代码, 是否吃了披萨, 当天快乐指数(0~10)
    // 快乐指数是待预测值
    // 特征名称
    let feature_names = vec!["Watched TV", "Pet Cat", "Rust LOC", "Ate Pizza"];

    // 特征值数量
    // original_data的列数 - 1(快乐指数)
    let num_features = original_data.len_of(Axis(1)) - 1;
    // 拿到original_data的 [所有行, 0~original_data列]
    let features = original_data.slice(s![..,0..num_features]).to_owned();
    // 快乐指数数组(训练数据的目标数据部分)
    // original_data的num_features列
    let labels = original_data.column(num_features).to_owned();

    // 建立数据集
    // (特征数组, 目标数组)
    let linfa_dataset = Dataset::new(features, labels)
        // 将快乐指数(i32)转换为字符串(表示心情)
        .map_targets(|x| match x.to_owned() as i32 {
            i32::MIN..=4 => "Sad",
            5..=7 => "Ok",
            8..=i32::MAX => "Happy",
        })
        // 指定特征名称
        .with_feature_names(feature_names);

    // 决策树模型
    let model = DecisionTree::params() // 预先准备的参数
        .split_quality(SplitQuality::Gini) // 分裂算法?
        .fit(&linfa_dataset) // 训练
        .unwrap();

    // 输出文件
    File::create("dt.tex")
        .unwrap()
        .write_all(model.export_to_tikz().with_legend().to_string().as_bytes())
        .unwrap();
}
